
Conversation

@dchigarev (Contributor) commented Dec 12, 2024

The PR implements logic that removes linalg.broadcast ops that do not perform any actual broadcasting. Example:

%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out) // doesn't make any sense; can be removed

Why not support all broadcast cases?
A proper lowering for broadcast can be tricky, since xegpu only supports 2D memrefs. Broadcast is always a shape-expanding operation, so there's always at least one operand that is not 2D.
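Concretely, the cases this PR targets are the ones where the broadcast only inserts unit dimensions. As a minimal C++ sketch (an illustration of how such a check could look, not the PR's exact code; it uses the upstream linalg::BroadcastOp accessors and assumes static memref shapes):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// A broadcast performs no actual broadcasting when every dimension it
// inserts has size 1 in the output, i.e. the data is only reshaped
// (e.g. 7x128 -> 1x7x128), never replicated.
static bool isNoopBroadcast(linalg::BroadcastOp broadcastOp) {
  auto outType = dyn_cast<MemRefType>(broadcastOp.getInit().getType());
  if (!outType || !outType.hasStaticShape())
    return false;
  // 'dimensions' lists the output dims added by the broadcast.
  for (int64_t dim : broadcastOp.getDimensions())
    if (outType.getDimSize(dim) != 1)
      return false;
  return true;
}

With such a check, the 7x128 -> 1x7x128 example above qualifies (the inserted dim has size 1), while e.g. 7x128 -> 4x7x128 would not.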

Comment on lines +1315 to +1325
// Checks whether the given linalgOp operand is produced by a
// `linalg::BroadcastOp` that can be replaced by a simple subview
// (for example broadcast: 7x128 -> 1x1x7x128) and ensures that
// the broadcast result is only used by the linalgOp in question.
//
// If a valid `linalg::BroadcastOp` is found, the function removes it
// and returns the operand of the `linalg::BroadcastOp` as the new
// linalgOp operand. Otherwise, returns the original operand.
static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
                                     size_t operandIdx,
                                     PatternRewriter &rewriter) {
@dchigarev (Contributor, Author) commented Dec 12, 2024

You may ask: why can't we just lower every such linalg.broadcast into something like memref.expand_shape via a separate pattern, instead of doing this "find a broadcast that produces an operand of a linalg-op that we're already lowering to xegpu" quest?

The problem is that the memref-to-spirv pass supports only a very limited set of memref ops. Essentially only memref.subview is supported, and it cannot expand memref shapes. So we can't simply replace linalg.broadcast with memref.expand_shape, since our pipeline would then fail:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU

// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

// ElementWiseToXeGPUPattern:
%out_squeeze = memref.subview %out : memref<1x7x128> to memref<7x128>
%desc = xegpu.create_tensor_desc %out_squeeze 
...

And although a human eye can see that the memref.expand_shape + memref.subview pair here can be eliminated, none of the upstream passes can do that. Even if such an expand_shape/subview merger pass existed, we still could not guarantee that the memref.expand_shape is always followed by a rank-reducing memref.subview it can be merged with. Example:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.trickyOp ins(%out, ...)

// --------- after LinalgToXeGPU

// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

// 'linalg.trickyOp' is not supported by LinalgToXeGPU pass
// no rank-reducing memref.subview to merge 'expand_shape' with
linalg.trickyOp ins(%out, ...)
...

// --------- after LinalgToLoops
// BroadcastToExpandShapePattern:
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.expand_shape %inp : memref<7x128> to memref<1x7x128> // <-- this will crash our pipeline

for {
   for {
      for {
          %outScalar = memref.load %out
          arith.trickyOp %outScalar
          ...
       }
   }
}
...

So the only option we're left with is to "lower" linalg.broadcast only when it produces an operand of a linalgOp that we're currently lowering to xegpu, and to do so simply by erasing the broadcastOp and forwarding its input to the corresponding input of the linalgOp in question. Example:

// --------- before LinalgToXeGPU
%inp = memref.alloc() : memref<7x128xf16>
%out = memref.alloc() : memref<1x7x128xf16>
linalg.broadcast ins(%inp) out(%out)
linalg.add ins(%out, ...)

// --------- after LinalgToXeGPU
// ElementWiseToXeGPUPattern:
%inp = memref.alloc() : memref<7x128xf16>
%desc = xegpu.create_tensor_desc %inp
...
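For context, a rough sketch of how this forwarding could be implemented (a hypothetical illustration mirroring the findAndReplaceBroadcast signature quoted above and reusing the isNoopBroadcast helper sketched earlier; not the PR's actual body):

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// With buffer semantics linalg.broadcast has no results: it writes into the
// consumer's operand memref, so we search the buffer's users for it.
static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
                                     size_t operandIdx,
                                     PatternRewriter &rewriter) {
  Value operand = linalgOp->getOperand(operandIdx);

  // Find a linalg.broadcast that writes into this buffer.
  linalg::BroadcastOp broadcastOp = nullptr;
  for (Operation *user : operand.getUsers()) {
    auto bcast = dyn_cast<linalg::BroadcastOp>(user);
    if (bcast && bcast.getInit() == operand) {
      broadcastOp = bcast;
      break;
    }
  }
  if (!broadcastOp || !isNoopBroadcast(broadcastOp))
    return operand;

  // Be conservative: bail out if anything besides the broadcast and the
  // linalgOp being lowered touches the buffer (a real implementation would
  // also have to tolerate deallocs, etc.).
  for (Operation *user : operand.getUsers())
    if (user != broadcastOp.getOperation() && user != linalgOp.getOperation())
      return operand;

  // Erase the broadcast and forward its (lower-rank) input directly.
  Value input = broadcastOp.getInput();
  rewriter.eraseOp(broadcastOp);
  return input;
}

Per the doc comment quoted above, the real helper also covers cases that need a subview (e.g. 7x128 -> 1x1x7x128); the sketch only shows the plain forwarding path.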


pm.addPass(createDecomposeTensorOperation());
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
pm.addPass(createCanonicalizerPass());
@dchigarev (Contributor, Author) commented Dec 12, 2024

We should do the 'cleaning' (the canonicalizer pass above) right after the tiling. Otherwise the bufferization pass may produce memref.cast ops that cannot be lowered by memref-to-spirv.

@dchigarev marked this pull request as ready for review on December 12, 2024 at 13:15
@kurapov-peter (Contributor) left a comment

It looks like we are just squeezing and expanding shapes back and forth. I wonder if we could just avoid creating the broadcast in the first place. That would have to go into the original module creation, and I see how that is tricky when we perform the conversion one operator at a time (e.g., given two elementwise ops, the first is converted to match the output shape with a broadcast; it only becomes clear that it is redundant once we have the complete module).

Let's test this out; it's done high enough so as not to change the overall behavior too much.

Signed-off-by: dchigarev <[email protected]>
@dchigarev merged commit 34fe67e into intel:main on Dec 13, 2024
6 checks passed